Fix: Resolve torch_dtype correctly for AMP #3189
Conversation
CodeRabbit walkthrough: Adjusts the `resolve_dtype` logic in `src/axolotl/utils/config/__init__.py` to default mixed-precision requests (`fp16`/`bf16`) to loading the base model in `torch.float32`, with a specific branch setting `torch_dtype` to bf16, and narrows the conditions that set `torch_dtype` to float16 by removing `fp16` from that path.
Thanks for the PR and for figuring this out! I think a lot of people would prefer bf16 master weights to save VRAM. Could an alternative solution be creating a new config option for this? We can then add a note to the FAQ to enable that option for fp16 if you hit the GradScaler error.
I don't agree with this approach; I think we should respect the user's choice of `fp16`/`bf16`.
Codecov Report: All modified and coverable lines are covered by tests.
I appreciate the replies! I think there might be some confusion stemming from the current config schema, which I also found challenging. Let me try to clarify my understanding of the issue.

As I understand it, mixed precision has a specific technical meaning: it uses fast half-precision (FP16/BF16) for compute while maintaining a high-precision FP32 master copy of the weights for the optimizer. This "mix" provides the speed of half-precision with the stability of FP32. When an engineer reads the axolotl docs page about enabling mixed precision, this is what they'll assume is happening when they toggle `fp16: true` or `bf16: true`.

This is distinct from "pure" half-precision training, where the model's weights, gradients, and optimizer states are all in BF16 (or FP16, which is generally unstable). This saves more VRAM but can sacrifice numerical stability.

In terms of how the configs appear to be intended to work in Axolotl, I understand it like this: `fp16: true` / `bf16: true` request mixed-precision (AMP) training with FP32 master weights, while `float16: true` / `bfloat16: true` request pure half-precision training where the weights themselves are stored in half precision.
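To make the "mix" concrete, here is a minimal sketch (my own illustration, not code from this PR; it assumes a CUDA/ROCm device) showing that under `torch.autocast` the weights stay in FP32 while the compute comes back in FP16:

```python
import torch

lin = torch.nn.Linear(8, 8).cuda()            # parameters remain FP32 master weights
x = torch.randn(4, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = lin(x)                              # the matmul runs through FP16 kernels

print(lin.weight.dtype, out.dtype)            # torch.float32 torch.float16
```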
The issue my PR addresses is that the original logic conflated these two distinct modes. When a user specified `fp16: true` or `bf16: true`, the model was loaded directly in half precision rather than FP32. This is the current scenario on main: with `fp16: true`, the model is loaded in FP16 and training crashes with a `GradScaler` error (the sketch below illustrates the failure).
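A minimal sketch of that failure mode (illustrative only; it assumes a CUDA/ROCm device, since `GradScaler` is inactive on CPU). With the model itself cast to FP16, the gradients are FP16 and `GradScaler` refuses to unscale them:

```python
import torch

model = torch.nn.Linear(64, 64).cuda().half()   # weights in FP16, as on main with fp16: true
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(8, 64, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).float().pow(2).mean()

scaler.scale(loss).backward()
scaler.step(opt)    # raises: ValueError: Attempting to unscale FP16 gradients.
```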
The existing config schema already provides a way to train in bf16 with bf16 master weights: setting `bfloat16: true`. This whole discussion highlights that the current flags (`fp16`/`bf16` vs `float16`/`bfloat16`) are easy to confuse, and the schema could probably be made clearer. But I'm aware this might be beyond the scope of this specific PR, which just aims to make behaviour under the current schema stable and correct by aligning it with the standard PyTorch AMP implementation.
Ok, after some internal discussion, I'm good with this PR now. My next thought would be whether to convert the existing example yamls to use `bfloat16: true` instead.
Description
The previous logic in `resolve_dtype` incorrectly configured the model's `torch_dtype` for Automatic Mixed Precision (AMP) training.

When `fp16: true` or `bf16: true` was set for mixed precision, the model was loaded directly into a half-precision format. This conflicts with the standard PyTorch AMP workflow, which expects the model to be loaded in FP32 to establish master weights before being managed by the `autocast` context and `GradScaler`. This misconfiguration led to `GradScaler` failures with FP16 (which expects FP32 weights) and an inefficient, non-standard AMP implementation for BF16.

This commit adjusts the logic to prioritize the AMP flags. If `fp16` or `bf16` is enabled, `torch_dtype` is now correctly resolved to `torch.float32`. The logic for the pure precision modes (`float16`, `bfloat16`) remains.
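As a rough illustration of the intended precedence (a simplified sketch only, not the actual axolotl implementation; the attribute names mirror the config flags discussed above, and the `bf16: "full"` branch is an assumption):

```python
import torch

def resolve_torch_dtype(cfg) -> torch.dtype:
    # Pure half-precision modes: the weights themselves are stored in half precision.
    if getattr(cfg, "bfloat16", False) or getattr(cfg, "bf16", None) == "full":
        return torch.bfloat16
    if getattr(cfg, "float16", False):
        return torch.float16
    # Mixed-precision (AMP) requests: load the model in FP32 so the optimizer keeps
    # FP32 master weights; autocast/GradScaler handle the half-precision compute.
    if getattr(cfg, "fp16", False) or getattr(cfg, "bf16", False):
        return torch.float32
    return torch.float32
```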
Motivation and Context
My hardware (AMD MI100) has twice the theoretical FP16 throughput compared with BF16, so I was interested in trying FP16 mixed precision despite the reduction in stability.
I initially observed failures when attempting to use FP16 mixed precision by toggling `fp16: true`. Doing so on an example config would produce a traceback ending in `ValueError: Attempting to unscale FP16 gradients.`
This error was consistent across a number of different configuration changes (flash attention, xformers, gradient accumulation, sample packing, etc.).
BF16 AMP would run without error, as it avoided the gradient scaling pathway.
A simple PyTorch reproducer with FP16 AMP worked without error.
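Something along these lines (an illustrative sketch of such a reproducer, not the exact script I used; it assumes a CUDA/ROCm device):

```python
import torch

device = "cuda"
model = torch.nn.Linear(256, 256).to(device)          # FP32 master weights
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

for _ in range(20):
    x = torch.randn(32, 256, device=device)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).float().pow(2).mean()
    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)                            # FP32 grads unscale without error
    scaler.update()
```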
As did a Hugging Face Trainer reproducer, so I knew the issue was specific to axolotl.
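A Trainer-based sketch in the same spirit (illustrative, not the exact script; it uses a tiny randomly initialised GPT-2 so nothing has to be downloaded, and needs a GPU for `fp16=True`):

```python
import torch
from torch.utils.data import Dataset
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments

class RandomTokens(Dataset):
    """Toy dataset of random token ids; labels mirror inputs for LM loss."""
    def __len__(self):
        return 64
    def __getitem__(self, idx):
        ids = torch.randint(0, 1000, (32,))
        return {"input_ids": ids, "labels": ids.clone()}

model = GPT2LMHeadModel(GPT2Config(vocab_size=1000, n_layer=2, n_head=2, n_embd=64))

args = TrainingArguments(
    output_dir="/tmp/fp16-amp-repro",   # hypothetical path
    per_device_train_batch_size=8,
    max_steps=20,
    fp16=True,                          # Trainer keeps the model in FP32 and applies AMP
    report_to=[],
)

Trainer(model=model, args=args, train_dataset=RandomTokens()).train()
```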
How has this been tested?
This change was validated on an AMD MI100 GPU (ROCm backend), where the original code consistently failed with `ValueError: Attempting to unscale FP16 gradients.` After applying this patch, FP16 AMP training runs successfully, with a significant performance improvement over FP32 as measured with a custom nanoGPT-style config.
Flash Attention 2 (which only supports FP16/BF16) can be enabled and provides a performance boost, implying that AMP is active and the attention passes run in lower precision as expected.
The fix should be backend-agnostic.
Example config to reproduce crash:
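An illustrative config along these lines should reproduce it on main (the base model, dataset, and output path here are placeholders, not the exact ones from my runs):

```yaml
base_model: openlm-research/open_llama_3b_v2   # placeholder base model
datasets:
  - path: teknium/GPT4-LLM-Cleaned             # placeholder dataset
    type: alpaca
sequence_len: 2048
micro_batch_size: 1
gradient_accumulation_steps: 4
num_epochs: 1
optimizer: adamw_torch
learning_rate: 0.0002
output_dir: ./outputs/fp16-repro
fp16: true     # requests FP16 mixed precision; crashes on main with the GradScaler error
```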
Types of changes
Bugfix